// Copyright 2015 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package authdbimpl

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
	"time"

	ds "go.chromium.org/luci/gae/service/datastore"

	"go.chromium.org/luci/common/clock"
	"go.chromium.org/luci/common/errors"
	"go.chromium.org/luci/common/logging"
	"go.chromium.org/luci/common/retry/transient"
	"go.chromium.org/luci/server/auth/service"
	"go.chromium.org/luci/server/auth/service/protocol"
)

// maxShardSize is a limit on a blob size to store in a single entity.
//
// fetchSnapshot passes it to storeDeflated: a deflated AuthDB strictly smaller
// than this is stored inline in the Snapshot entity, anything else is split
// into SnapshotShard entities holding at most this many bytes each.
const maxShardSize = 1020 * 1024 // 1020 KiB

// SnapshotInfo points at one concrete AuthDB snapshot.
//
// It is a singleton entity (fixed kind and ID). The AuthDB proto message
// itself lives in a separate Snapshot entity whose ID is derived from the
// fields below, see GetSnapshotID.
type SnapshotInfo struct {
	AuthServiceURL string `gae:",noindex"`
	Rev            int64  `gae:",noindex"`

	_kind string `gae:"$kind,gaeauth.SnapshotInfo"`
	_id   int64  `gae:"$id,1"`
}

// GetSnapshotID returns datastore ID of the corresponding Snapshot entity.
//
// The ID has the form "v1,<AuthServiceURL>,<Rev>". Since ',' is used as the
// separator, the method panics if AuthServiceURL itself contains a comma.
func (si *SnapshotInfo) GetSnapshotID() string {
	authURL := si.AuthServiceURL
	if strings.Contains(authURL, ",") {
		panic(fmt.Errorf("forbidden symbol ',' in URL %q", authURL))
	}
	return fmt.Sprintf("v1,%s,%d", authURL, si.Rev)
}

// Snapshot is serialized deflated AuthDB blob with some minimal metadata.
//
// Root entity. Immutable. Key has the form "v1,<AuthServiceURL>,<Revision>",
// it's generated by SnapshotInfo.GetSnapshotID(). It is a globally unique
// version identifier, since it includes URL of an auth service. AuthServiceURL
// should not be very long (~< 250 chars) for this to work.
//
// Currently does not get garbage collected.
type Snapshot struct {
	// ID is the entity ID, as produced by SnapshotInfo.GetSnapshotID().
	ID string `gae:"$id"`

	// AuthDBDeflated is zlib-compressed serialized AuthDB protobuf message.
	//
	// If it is too big, it is stored in a bunch of SnapshotShard entities
	// referenced by ShardIDs field below.
	//
	// Note: if the old version of this code tries to load a new Snapshot entity
	// with ShardIDs field populated, it would abort with an error because old
	// code doesn't know about ShardIDs field (it is not in the old Snapshot
	// entity struct). This is desirable: the new sharded data structure is not
	// (and can't be made) compatible with old code, so it is good that it breaks
	// as soon as possible.
	AuthDBDeflated []byte `gae:",noindex"`

	// ShardIDs is a list of IDs of SnapshotShard entities to fetch.
	//
	// When non-empty, the deflated AuthDB is the concatenation of the shards in
	// the listed order (see unshardAuthDB).
	ShardIDs []string `gae:",noindex"`

	CreatedAt time.Time // when it was created on Auth service
	FetchedAt time.Time // when it was fetched and put into the datastore

	_kind string `gae:"$kind,gaeauth.Snapshot"`
}

// SnapshotShard holds a shard of a deflated AuthDB.
//
// Shards are written by shardAuthDB and read back by unshardAuthDB. The order
// in which they must be concatenated is recorded in Snapshot.ShardIDs, not in
// the shard itself.
type SnapshotShard struct {
	// ID is "<Snapshot ID>:<shard hash>", where the hash is the hex-encoded
	// SHA-256 digest of Shard.
	ID string `gae:"$id"`
	// Shard is the actual data.
	Shard []byte `gae:",noindex"`

	_kind string `gae:"$kind,gaeauth.SnapshotShard"`
}

// GetLatestSnapshotInfo loads the SnapshotInfo singleton from the datastore.
//
// The read happens outside of any transaction and in the default namespace.
//
// Returns (nil, nil) if the entity doesn't exist. Any other datastore error is
// returned tagged as transient. The outcome is reported through the
// latestSnapshotInfoDuration reporter.
func GetLatestSnapshotInfo(ctx context.Context) (*SnapshotInfo, error) {
	report := durationReporter(ctx, latestSnapshotInfoDuration)
	logging.Debugf(ctx, "Fetching AuthDB snapshot info from the datastore")
	ctx = ds.WithoutTransaction(defaultNS(ctx))

	var info SnapshotInfo
	err := ds.Get(ctx, &info)

	// A missing entity is not an error: nothing is configured yet.
	if err == ds.ErrNoSuchEntity {
		report("SUCCESS")
		return nil, nil
	}
	if err != nil {
		report("ERROR_TRANSIENT")
		return nil, transient.Tag.Apply(err)
	}

	report("SUCCESS")
	return &info, nil
}

// deleteSnapshotInfo removes SnapshotInfo entity from the datastore.
//
// Used to detach the service from auth_service.
//
// Like the other datastore accessors in this file, it operates outside of any
// transaction and explicitly in the default namespace, so it deletes the same
// entity GetLatestSnapshotInfo reads regardless of the namespace `ctx` carries.
//
// Returns the datastore error as is (not tagged as transient).
func deleteSnapshotInfo(ctx context.Context) error {
	ctx = ds.WithoutTransaction(defaultNS(ctx))
	return ds.Delete(ctx, ds.KeyForObj(ctx, &SnapshotInfo{}))
}

// GetAuthDBSnapshot loads the snapshot with the given ID from the datastore,
// inflates it and deserializes it into an AuthDB message.
//
// `id` is the Snapshot entity ID (see SnapshotInfo.GetSnapshotID).
//
// Errors from fetchDeflated and from inflation are returned as is. The outcome
// is reported through the getSnapshotDuration reporter.
func GetAuthDBSnapshot(ctx context.Context, id string) (*protocol.AuthDB, error) {
	report := durationReporter(ctx, getSnapshotDuration)
	logging.Debugf(ctx, "Fetching AuthDB snapshot from the datastore")
	// Deferred, so this line is emitted on every return path, errors included.
	defer logging.Debugf(ctx, "AuthDB snapshot fetched")

	deflated, status, err := fetchDeflated(ctx, id)
	if err != nil {
		report(status)
		return nil, err
	}

	authDB, err := service.InflateAuthDB(deflated)
	if err != nil {
		report("ERROR_INFLATION")
		return nil, err
	}

	report("SUCCESS")
	return authDB, nil
}

// fetchDeflated loads a deflated AuthDB blob from the datastore, reassembling
// it from SnapshotShard entities if the Snapshot references any.
//
// Besides the blob it returns a short status code suitable for reporting:
// "SUCCESS", "ERROR_NO_SNAPSHOT" (no such Snapshot, not transient),
// "ERROR_TRANSIENT", "ERROR_SHARDS_TRANSIENT" or "ERROR_SHARDS_MISSING".
//
// See also storeDeflated.
func fetchDeflated(ctx context.Context, id string) (blob []byte, code string, err error) {
	ctx = ds.WithoutTransaction(defaultNS(ctx))

	snap := Snapshot{ID: id}
	if err = ds.Get(ctx, &snap); err != nil {
		if err == ds.ErrNoSuchEntity {
			return nil, "ERROR_NO_SNAPSHOT", err // not transient
		}
		return nil, "ERROR_TRANSIENT", transient.Tag.Apply(err)
	}

	// Small snapshots are stored inline, nothing to reassemble.
	if len(snap.ShardIDs) == 0 {
		return snap.AuthDBDeflated, "SUCCESS", nil
	}

	logging.Infof(ctx, "Reconstructing from %d shards", len(snap.ShardIDs))
	joined, err := unshardAuthDB(ctx, snap.ShardIDs)
	switch {
	case transient.Tag.In(err):
		return nil, "ERROR_SHARDS_TRANSIENT", err
	case err != nil:
		return nil, "ERROR_SHARDS_MISSING", err
	}
	return joined, "SUCCESS", nil
}

// ConfigureAuthService makes initial fetch of AuthDB snapshot from the auth
// service and sets up PubSub subscription.
//
// `baseURL` is root URL of currently running service, it is passed to
// setupPubSub to derive PubSub push endpoint URL.
//
// If `authServiceURL` is blank, disables the fetching: the PubSub subscription
// of the previously configured auth service (if any) is killed and the
// SnapshotInfo entity is deleted.
//
// Otherwise the steps are, in order: fetch and store the latest snapshot,
// set up the PubSub subscription, store the SnapshotInfo pointer, and finally
// kill the subscription of the previously used auth service if it differs.
// The first failing step aborts the sequence and its error is returned.
func ConfigureAuthService(ctx context.Context, baseURL, authServiceURL string) error {
	logging.Infof(ctx, "Reconfiguring AuthDB to be fetched from %q", authServiceURL)
	ctx = defaultNS(ctx)

	// Remember which auth service (if any) is configured right now: when
	// switching services we need to unsubscribe from its PubSub stream.
	current, err := GetLatestSnapshotInfo(ctx)
	if err != nil {
		return err
	}
	var oldURL string
	if current != nil {
		oldURL = current.AuthServiceURL
	}

	// Blank URL means the synchronization should be stopped completely.
	if authServiceURL == "" {
		if oldURL != "" {
			if err := killPubSub(ctx, oldURL); err != nil {
				return err
			}
		}
		return deleteSnapshotInfo(ctx)
	}

	// Fetch the latest snapshot and store it, which also verifies that
	// authServiceURL works end-to-end.
	rev, err := getAuthService(ctx, authServiceURL).GetLatestSnapshotRevision(ctx)
	if err != nil {
		return err
	}
	target := &SnapshotInfo{
		AuthServiceURL: authServiceURL,
		Rev:            rev,
	}
	if err := fetchSnapshot(ctx, target); err != nil {
		logging.Errorf(ctx, "Failed to fetch latest snapshot from %s - %s", authServiceURL, err)
		return err
	}

	// Subscribe to future updates.
	if err := setupPubSub(ctx, baseURL, authServiceURL); err != nil {
		logging.Errorf(ctx, "Failed to configure pubsub subscription - %s", err)
		return err
	}

	// Everything is in place. Flip the SnapshotInfo pointer to the new
	// snapshot: from now on syncAuthDB fetches changes from `authServiceURL`,
	// which makes it the main auth service.
	if err := ds.Put(ds.WithoutTransaction(ctx), target); err != nil {
		return transient.Tag.Apply(err)
	}

	// Nothing to unsubscribe from if there was no previous service or it is the
	// same one.
	if oldURL == "" || oldURL == authServiceURL {
		return nil
	}
	return killPubSub(ctx, oldURL)
}

// fetchSnapshot downloads the AuthDB snapshot identified by `info` from the
// auth service, deflates it and stores it in the datastore under the ID
// returned by info.GetSnapshotID().
//
// Idempotent. Doesn't touch SnapshotInfo entity itself, and thus always safe
// to call.
//
// On success logs how far the current time is from the snapshot creation time.
func fetchSnapshot(ctx context.Context, info *SnapshotInfo) error {
	snap, err := getAuthService(ctx, info.AuthServiceURL).GetSnapshot(ctx, info.Rev)
	if err != nil {
		return err
	}

	deflated, err := service.DeflateAuthDB(snap.AuthDB)
	if err != nil {
		return err
	}

	snapshotID := info.GetSnapshotID()
	if err = storeDeflated(ctx, snapshotID, deflated, snap.Created, maxShardSize); err != nil {
		return err
	}

	logging.Infof(ctx, "Lag: %s", clock.Now(ctx).Sub(snap.Created))
	return nil
}

// storeDeflated writes a deflated AuthDB into the datastore as a Snapshot
// entity with the given ID, splitting the blob into SnapshotShard entities if
// it doesn't fit inline.
//
// `created` is recorded as CreatedAt, the current time as FetchedAt (both in
// UTC). `maxShardSize` is both the inline threshold and the maximum size of a
// single shard.
//
// See also fetchDeflated.
func storeDeflated(ctx context.Context, id string, blob []byte, created time.Time, maxShardSize int) error {
	ctx = ds.WithoutTransaction(defaultNS(ctx))

	snapshot := Snapshot{
		ID:        id,
		CreatedAt: created.UTC(),
		FetchedAt: clock.Now(ctx).UTC(),
	}

	// Prefer storing the blob inline: older versions of this code can still
	// read such entities. A blob that doesn't fit has to go into shards, and
	// the old code will then fail on the unrecognized ShardIDs field.
	if len(blob) >= maxShardSize {
		shardIDs, err := shardAuthDB(ctx, id, blob, maxShardSize)
		if err != nil {
			return err
		}
		snapshot.ShardIDs = shardIDs
		logging.Infof(ctx, "Split into %d shards", len(snapshot.ShardIDs))
	} else {
		snapshot.AuthDBDeflated = blob
	}

	return transient.Tag.Apply(ds.Put(ctx, &snapshot))
}

// syncAuthDB fetches latest AuthDB snapshot from the configured auth service,
// puts it into the datastore and updates SnapshotInfo entity to point to it.
//
// Expects authenticating transport to be in the context. Called when receiving
// PubSub notifications.
//
// Returns SnapshotInfo of the most recent snapshot:
//   - the stored one, if the auth service has nothing newer (or reports an
//     older revision);
//   - the new one, if the pointer was advanced;
//   - whatever is in the datastore at commit time, if a concurrent update
//     already moved the pointer or switched the auth service. If the entity
//     was deleted meanwhile, this is an empty (zero) SnapshotInfo.
//
// Fails if the auth service is not configured, if talking to it fails, or if
// the final transaction fails (tagged as transient in that case). The outcome
// is reported through the syncAuthDBDuration reporter.
func syncAuthDB(ctx context.Context) (*SnapshotInfo, error) {
	report := durationReporter(ctx, syncAuthDBDuration)

	// `info` is what we have in the datastore now.
	info, err := GetLatestSnapshotInfo(ctx)
	if err != nil {
		report("ERROR_GET_LATEST_INFO")
		return nil, err
	}
	if info == nil {
		report("ERROR_NOT_CONFIGURED")
		return nil, errors.New("auth_service URL is not configured")
	}

	// Grab revision number of the latest snapshot on the server.
	srv := getAuthService(ctx, info.AuthServiceURL)
	latestRev, err := srv.GetLatestSnapshotRevision(ctx)
	if err != nil {
		report("ERROR_GET_LATEST_REVISION")
		return nil, err
	}

	// Nothing new?
	if info.Rev == latestRev {
		logging.Infof(ctx, "AuthDB is up-to-date at revision %d", latestRev)
		report("SUCCESS_UP_TO_DATE")
		return info, nil
	}

	// Auth service traveled back in time?
	if info.Rev > latestRev {
		logging.Errorf(
			ctx, "Latest AuthDB revision on server is %d, we have %d. It should not happen",
			latestRev, info.Rev)
		report("SUCCESS_NEWER_ALREADY")
		return info, nil
	}

	// Fetch the actual snapshot from the server and put it into the datastore.
	info.Rev = latestRev
	if err = fetchSnapshot(ctx, info); err != nil {
		logging.Errorf(ctx, "Failed to fetch snapshot %d from %q - %s", info.Rev, info.AuthServiceURL, err)
		report("ERROR_FETCHING")
		return nil, err
	}

	// Move pointer to the latest snapshot only if it is more recent than what is
	// already in the datastore.
	//
	// The transaction is pinned to the default namespace, same as the read in
	// GetLatestSnapshotInfo and the write in storeDeflated. Otherwise a context
	// carrying another namespace would look up SnapshotInfo in the wrong place,
	// find nothing and silently skip the update.
	var latest *SnapshotInfo
	err = ds.RunInTransaction(ds.WithoutTransaction(defaultNS(ctx)), func(ctx context.Context) error {
		// Reset on each attempt, the callback may be retried.
		latest = &SnapshotInfo{}
		switch err := ds.Get(ctx, latest); {
		case err == ds.ErrNoSuchEntity:
			logging.Warningf(ctx, "No longer need to fetch AuthDB, not configured anymore")
			return nil
		case err != nil:
			return err
		case latest.AuthServiceURL != info.AuthServiceURL:
			logging.Warningf(
				ctx, "No longer need to fetch AuthDB from %q, %q is primary now",
				info.AuthServiceURL, latest.AuthServiceURL)
			return nil
		case latest.Rev >= info.Rev:
			logging.Warningf(ctx, "Already have rev %d", info.Rev)
			return nil
		}
		latest = info
		return ds.Put(ctx, info)
	}, nil)

	if err != nil {
		report("ERROR_COMMITTING")
		return nil, transient.Tag.Apply(err)
	}

	report("SUCCESS_UPDATED")
	return latest, nil
}

// shardAuthDB splits an AuthDB blob into multiple SnapshotShard entities.
//
// The blob is cut into consecutive chunks of at most `maxSize` bytes. Each
// chunk is stored as a SnapshotShard with ID "<id>:<hex SHA-256 of the chunk>".
// Returns shard IDs in the order the chunks appear in the blob.
//
// Returns (nil, nil) for an empty blob. Returns an error if there is data to
// shard but `maxSize` is not positive: such a size can't make progress (zero
// would loop forever storing empty shards, negative is not a valid length).
// Datastore errors are returned tagged as transient.
//
// Uses `ctx` as is: callers are expected to pass a context with the right
// namespace and transaction state (storeDeflated does).
func shardAuthDB(ctx context.Context, id string, blob []byte, maxSize int) ([]string, error) {
	if len(blob) == 0 {
		return nil, nil
	}
	if maxSize <= 0 {
		return nil, fmt.Errorf("invalid max shard size %d, must be positive", maxSize)
	}

	// The number of shards is known upfront: ceil(len(blob) / maxSize).
	ids := make([]string, 0, (len(blob)-1)/maxSize+1)

	var shard []byte
	for len(blob) != 0 {
		shardSize := maxSize
		if shardSize > len(blob) {
			shardSize = len(blob)
		}
		shard, blob = blob[:shardSize], blob[shardSize:]

		digest := sha256.Sum256(shard)
		shardID := fmt.Sprintf("%s:%s", id, hex.EncodeToString(digest[:]))
		ids = append(ids, shardID)

		// Store shards sequentially to avoid allocating RAM to store full `blob` in
		// RPC buffers. There's no requirement for this code to be performant, it
		// executes in a background job.
		err := ds.Put(ctx, &SnapshotShard{ID: shardID, Shard: shard})
		if err != nil {
			return nil, transient.Tag.Apply(err)
		}
	}

	return ids, nil
}

// unshardAuthDB loads the given SnapshotShard entities in one datastore call
// and concatenates their payloads in the order of `shardIDs`.
//
// If at least one shard doesn't exist, the datastore error is returned as is
// (fatal). Any other failure is returned tagged as transient.
func unshardAuthDB(ctx context.Context, shardIDs []string) ([]byte, error) {
	shards := make([]SnapshotShard, 0, len(shardIDs))
	for _, id := range shardIDs {
		shards = append(shards, SnapshotShard{ID: id})
	}

	if err := ds.Get(ctx, shards); err != nil {
		// A per-entity error list may contain a missing shard, retrying won't
		// help in that case.
		if perShard, ok := err.(errors.MultiError); ok {
			for _, inner := range perShard {
				if inner == ds.ErrNoSuchEntity {
					return nil, err // fatal
				}
			}
		}
		// Either an overall RPC error or per-entity errors that are not about
		// missing shards.
		return nil, transient.Tag.Apply(err)
	}

	chunks := make([][]byte, 0, len(shards))
	for i := range shards {
		chunks = append(chunks, shards[i].Shard)
	}
	return bytes.Join(chunks, nil), nil
}
